home *** CD-ROM | disk | FTP | other *** search
/ MacWorld 2005 March / Macworld CD March 2005 - Marathon Trilogy.iso / Shareware World / iPod / iPodderX.sit / iPodderX / iPodderX.app / Contents / Resources / RawServer.py < prev    next >
Encoding:
Python Source  |  2005-01-07  |  17.9 KB  |  624 lines

  1. # Written by Bram Cohen
  2. # see LICENSE.txt for license information
  3.  
  4. from bisect import insort
  5. import socket
  6. from cStringIO import StringIO
  7. from traceback import print_exc
  8. from errno import EWOULDBLOCK, ENOBUFS
  9. try:
  10.     from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
  11.     timemult = 1000
  12. except ImportError:
  13.     from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
  14.     timemult = 1
  15. from threading import Thread, Event
  16. from time import time, sleep
  17. import sys
  18. from random import randrange
  19.  
  20. all = POLLIN | POLLOUT
  21.  
  22. class SingleSocket:
  23.     def __init__(self, raw_server, sock, handler):
  24.         self.raw_server = raw_server
  25.         self.socket = sock
  26.         self.handler = handler
  27.         self.buffer = []
  28.         self.last_hit = time()
  29.         self.fileno = sock.fileno()
  30.         self.connected = False
  31.         
  32.     def get_ip(self):
  33.         try:
  34.             return self.socket.getpeername()[0]
  35.         except socket.error:
  36.             return 'no connection'
  37.         
  38.     def close(self):
  39.         sock = self.socket
  40.         self.socket = None
  41.         self.buffer = []
  42.         del self.raw_server.single_sockets[self.fileno]
  43.         self.raw_server.poll.unregister(sock)
  44.         sock.close()
  45.  
  46.     def shutdown(self, val):
  47.         self.socket.shutdown(val)
  48.  
  49.     def is_flushed(self):
  50.         return len(self.buffer) == 0
  51.  
  52.     def write(self, s):
  53.         assert self.socket is not None
  54.         self.buffer.append(s)
  55.         if len(self.buffer) == 1:
  56.             self.try_write()
  57.  
  58.     def try_write(self):
  59.         if self.connected:
  60.             try:
  61.                 while self.buffer != []:
  62.                     amount = self.socket.send(self.buffer[0])
  63.                     if amount != len(self.buffer[0]):
  64.                         if amount != 0:
  65.                             self.buffer[0] = self.buffer[0][amount:]
  66.                         break
  67.                     del self.buffer[0]
  68.             except socket.error, e:
  69.                 code, msg = e
  70.                 if code != EWOULDBLOCK:
  71.                     self.raw_server.dead_from_write.append(self)
  72.                     return
  73.         if self.buffer == []:
  74.             self.raw_server.poll.register(self.socket, POLLIN)
  75.         else:
  76.             self.raw_server.poll.register(self.socket, all)
  77.  
  78. def default_error_handler(x):
  79.     print x
  80.  
  81. class RawServer:
  82.     def __init__(self, doneflag, timeout_check_interval, timeout, noisy = True,
  83.             errorfunc = default_error_handler, maxconnects = 55):
  84.         self.timeout_check_interval = timeout_check_interval
  85.         self.timeout = timeout
  86.         self.poll = poll()
  87.         # {socket: SingleSocket}
  88.         self.single_sockets = {}
  89.         self.dead_from_write = []
  90.         self.doneflag = doneflag
  91.         self.noisy = noisy
  92.         self.errorfunc = errorfunc
  93.         self.maxconnects = maxconnects
  94.         self.funcs = []
  95.         self.unscheduled_tasks = []
  96.         self.add_task(self.scan_for_timeouts, timeout_check_interval)
  97.  
  98.     def add_task(self, func, delay):
  99.         self.unscheduled_tasks.append((func, delay))
  100.  
  101.     def scan_for_timeouts(self):
  102.         self.add_task(self.scan_for_timeouts, self.timeout_check_interval)
  103.         t = time() - self.timeout
  104.         tokill = []
  105.         for s in self.single_sockets.values():
  106.             if s.last_hit < t:
  107.                 tokill.append(s)
  108.         for k in tokill:
  109.             if k.socket is not None:
  110.                 self._close_socket(k)
  111.  
  112.     def bind(self, port, bind = '', reuse = False):
  113.         self.bindaddr = bind
  114.         server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  115.         if reuse:
  116.             server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  117.         server.setblocking(0)
  118.         try:
  119.             server.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, 32)
  120.         except:
  121.             pass
  122.         server.bind((bind, port))
  123.         server.listen(5)
  124.         self.poll.register(server, POLLIN)
  125.         self.server = server
  126.  
  127.     def start_connection(self, dns, handler = None):
  128.         if handler is None:
  129.             handler = self.handler
  130.         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  131.         sock.setblocking(0)
  132.         sock.bind((self.bindaddr, 0))
  133.         try:
  134.             sock.connect_ex(dns)
  135.         except socket.error:
  136.             raise
  137.         except Exception, e:
  138.             raise socket.error(str(e))
  139.         self.poll.register(sock, POLLIN)
  140.         s = SingleSocket(self, sock, handler)
  141.         self.single_sockets[sock.fileno()] = s
  142.         return s
  143.         
  144.     def handle_events(self, events):
  145.         for sock, event in events:
  146.             if sock == self.server.fileno():
  147.                 if event & (POLLHUP | POLLERR) != 0:
  148.                     self.poll.unregister(self.server)
  149.                     self.server.close()
  150.                     self.errorfunc('lost server socket')
  151.                 else:
  152.                     try:
  153.                         newsock, addr = self.server.accept()
  154.                         newsock.setblocking(0)
  155.                         if len(self.single_sockets) >= self.maxconnects:
  156.                             newsock.close()
  157.                             continue
  158.                         nss = SingleSocket(self, newsock, self.handler)
  159.                         self.single_sockets[newsock.fileno()] = nss
  160.                         self.poll.register(newsock, POLLIN)
  161.                         self.handler.external_connection_made(nss)
  162.                     except socket.error:
  163.                         sleep(1)
  164.             else:
  165.                 s = self.single_sockets.get(sock)
  166.                 if s is None:
  167.                     continue
  168.                 s.connected = True
  169.                 if (event & (POLLHUP | POLLERR)) != 0:
  170.                     self._close_socket(s)
  171.                     continue
  172.                 if (event & POLLIN) != 0:
  173.                     try:
  174.                         s.last_hit = time()
  175.                         data = s.socket.recv(100000)
  176.                         if data == '':
  177.                             self._close_socket(s)
  178.                         else:
  179.                             s.handler.data_came_in(s, data)
  180.                     except socket.error, e:
  181.                         code, msg = e
  182.                         if code != EWOULDBLOCK:
  183.                             self._close_socket(s)
  184.                             continue
  185.                 if (event & POLLOUT) != 0 and s.socket is not None and not s.is_flushed():
  186.                     s.try_write()
  187.                     if s.is_flushed():
  188.                         s.handler.connection_flushed(s)
  189.  
  190.     def pop_unscheduled(self):
  191.         try:
  192.             while True:
  193.                 (func, delay) = self.unscheduled_tasks.pop()
  194.                 insort(self.funcs, (time() + delay, func))
  195.         except IndexError:
  196.             pass
  197.  
  198.     def listen_forever(self, handler):
  199.         self.handler = handler
  200.         try:
  201.             while not self.doneflag.isSet():
  202.                 try:
  203.                     self.pop_unscheduled()
  204.                     if len(self.funcs) == 0:
  205.                         period = 2 ** 30
  206.                     else:
  207.                         period = self.funcs[0][0] - time()
  208.                     if period < 0:
  209.                         period = 0
  210.                     events = self.poll.poll(period * timemult)
  211.                     if self.doneflag.isSet():
  212.                         return
  213.                     while len(self.funcs) > 0 and self.funcs[0][0] <= time():
  214.                         garbage, func = self.funcs[0]
  215.                         del self.funcs[0]
  216.                         try:
  217.                             func()
  218.                         except KeyboardInterrupt:
  219.                             #print_exc()
  220.                             return
  221.                         except:
  222.                             if self.noisy:
  223.                                 data = StringIO()
  224.                                 print_exc(file = data)
  225.                                 self.errorfunc(data.getvalue())
  226.                     self._close_dead()
  227.                     self.handle_events(events)
  228.                     if self.doneflag.isSet():
  229.                         return
  230.                     self._close_dead()
  231.                 except error, e:
  232.                     if self.doneflag.isSet():
  233.                         return
  234.                     # I can't find a coherent explanation for what the behavior should be here,
  235.                     # and people report conflicting behavior, so I'll just try all the possibilities
  236.                     try:
  237.                         code, msg, desc = e
  238.                     except:
  239.                         try:
  240.                             code, msg = e
  241.                         except:
  242.                             code = ENOBUFS
  243.                     if code == ENOBUFS:
  244.                         self.errorfunc("Have to exit due to the TCP stack flaking out")
  245.                         return
  246.                 except KeyboardInterrupt:
  247.                     #print_exc()
  248.                     return
  249.                 except:
  250.                     data = StringIO()
  251.                     print_exc(file = data)
  252.                     self.errorfunc(data.getvalue())
  253.         finally:
  254.             for ss in self.single_sockets.values():
  255.                 ss.close()
  256.             self.server.close()
  257.  
  258.     def _close_dead(self):
  259.         while len(self.dead_from_write) > 0:
  260.             old = self.dead_from_write
  261.             self.dead_from_write = []
  262.             for s in old:
  263.                 if s.socket is not None:
  264.                     self._close_socket(s)
  265.  
  266.     def _close_socket(self, s):
  267.         sock = s.socket.fileno()
  268.         s.socket.close()
  269.         self.poll.unregister(sock)
  270.         del self.single_sockets[sock]
  271.         s.socket = None
  272.         s.handler.connection_lost(s)
  273.  
  274. # everything below is for testing
  275.  
  276. class DummyHandler:
  277.     def __init__(self):
  278.         self.external_made = []
  279.         self.data_in = []
  280.         self.lost = []
  281.  
  282.     def external_connection_made(self, s):
  283.         self.external_made.append(s)
  284.     
  285.     def data_came_in(self, s, data):
  286.         self.data_in.append((s, data))
  287.     
  288.     def connection_lost(self, s):
  289.         self.lost.append(s)
  290.  
  291.     def connection_flushed(self, s):
  292.         pass
  293.  
  294. def sl(rs, handler, port):
  295.     rs.bind(port)
  296.     Thread(target = rs.listen_forever, args = [handler]).start()
  297.  
  298. def loop(rs):
  299.     x = []
  300.     def r(rs = rs, x = x):
  301.         rs.add_task(x[0], .1)
  302.     x.append(r)
  303.     rs.add_task(r, .1)
  304.  
  305. beginport = 5000 + randrange(10000)
  306.  
  307. def test_starting_side_close():
  308.     try:
  309.         fa = Event()
  310.         fb = Event()
  311.         da = DummyHandler()
  312.         sa = RawServer(fa, 100, 100)
  313.         loop(sa)
  314.         sl(sa, da, beginport)
  315.         db = DummyHandler()
  316.         sb = RawServer(fb, 100, 100)
  317.         loop(sb)
  318.         sl(sb, db, beginport + 1)
  319.  
  320.         sleep(.5)
  321.         ca = sa.start_connection(('127.0.0.1', beginport + 1))
  322.         sleep(1)
  323.         
  324.         assert da.external_made == []
  325.         assert da.data_in == []
  326.         assert da.lost == []
  327.         assert len(db.external_made) == 1
  328.         cb = db.external_made[0]
  329.         del db.external_made[:]
  330.         assert db.data_in == []
  331.         assert db.lost == []
  332.  
  333.         ca.write('aaa')
  334.         cb.write('bbb')
  335.         sleep(1)
  336.         
  337.         assert da.external_made == []
  338.         assert da.data_in == [(ca, 'bbb')]
  339.         del da.data_in[:]
  340.         assert da.lost == []
  341.         assert db.external_made == []
  342.         assert db.data_in == [(cb, 'aaa')]
  343.         del db.data_in[:]
  344.         assert db.lost == []
  345.  
  346.         ca.write('ccc')
  347.         cb.write('ddd')
  348.         sleep(1)
  349.         
  350.         assert da.external_made == []
  351.         assert da.data_in == [(ca, 'ddd')]
  352.         del da.data_in[:]
  353.         assert da.lost == []
  354.         assert db.external_made == []
  355.         assert db.data_in == [(cb, 'ccc')]
  356.         del db.data_in[:]
  357.         assert db.lost == []
  358.  
  359.         ca.close()
  360.         sleep(1)
  361.  
  362.         assert da.external_made == []
  363.         assert da.data_in == []
  364.         assert da.lost == []
  365.         assert db.external_made == []
  366.         assert db.data_in == []
  367.         assert db.lost == [cb]
  368.         del db.lost[:]
  369.     finally:
  370.         fa.set()
  371.         fb.set()
  372.  
  373. def test_receiving_side_close():
  374.     try:
  375.         da = DummyHandler()
  376.         fa = Event()
  377.         sa = RawServer(fa, 100, 100)
  378.         loop(sa)
  379.         sl(sa, da, beginport + 2)
  380.         db = DummyHandler()
  381.         fb = Event()
  382.         sb = RawServer(fb, 100, 100)
  383.         loop(sb)
  384.         sl(sb, db, beginport + 3)
  385.         
  386.         sleep(.5)
  387.         ca = sa.start_connection(('127.0.0.1', beginport + 3))
  388.         sleep(1)
  389.         
  390.         assert da.external_made == []
  391.         assert da.data_in == []
  392.         assert da.lost == []
  393.         assert len(db.external_made) == 1
  394.         cb = db.external_made[0]
  395.         del db.external_made[:]
  396.         assert db.data_in == []
  397.         assert db.lost == []
  398.  
  399.         ca.write('aaa')
  400.         cb.write('bbb')
  401.         sleep(1)
  402.         
  403.         assert da.external_made == []
  404.         assert da.data_in == [(ca, 'bbb')]
  405.         del da.data_in[:]
  406.         assert da.lost == []
  407.         assert db.external_made == []
  408.         assert db.data_in == [(cb, 'aaa')]
  409.         del db.data_in[:]
  410.         assert db.lost == []
  411.  
  412.         ca.write('ccc')
  413.         cb.write('ddd')
  414.         sleep(1)
  415.         
  416.         assert da.external_made == []
  417.         assert da.data_in == [(ca, 'ddd')]
  418.         del da.data_in[:]
  419.         assert da.lost == []
  420.         assert db.external_made == []
  421.         assert db.data_in == [(cb, 'ccc')]
  422.         del db.data_in[:]
  423.         assert db.lost == []
  424.  
  425.         cb.close()
  426.         sleep(1)
  427.  
  428.         assert da.external_made == []
  429.         assert da.data_in == []
  430.         assert da.lost == [ca]
  431.         del da.lost[:]
  432.         assert db.external_made == []
  433.         assert db.data_in == []
  434.         assert db.lost == []
  435.     finally:
  436.         fa.set()
  437.         fb.set()
  438.  
  439. def test_connection_refused():
  440.     try:
  441.         da = DummyHandler()
  442.         fa = Event()
  443.         sa = RawServer(fa, 100, 100)
  444.         loop(sa)
  445.         sl(sa, da, beginport + 6)
  446.  
  447.         sleep(.5)
  448.         ca = sa.start_connection(('127.0.0.1', beginport + 15))
  449.         sleep(1)
  450.         
  451.         assert da.external_made == []
  452.         assert da.data_in == []
  453.         assert da.lost == [ca]
  454.         del da.lost[:]
  455.     finally:
  456.         fa.set()
  457.  
  458. def test_both_close():
  459.     try:
  460.         da = DummyHandler()
  461.         fa = Event()
  462.         sa = RawServer(fa, 100, 100)
  463.         loop(sa)
  464.         sl(sa, da, beginport + 4)
  465.  
  466.         sleep(1)
  467.         db = DummyHandler()
  468.         fb = Event()
  469.         sb = RawServer(fb, 100, 100)
  470.         loop(sb)
  471.         sl(sb, db, beginport + 5)
  472.  
  473.         sleep(.5)
  474.         ca = sa.start_connection(('127.0.0.1', beginport + 5))
  475.         sleep(1)
  476.         
  477.         assert da.external_made == []
  478.         assert da.data_in == []
  479.         assert da.lost == []
  480.         assert len(db.external_made) == 1
  481.         cb = db.external_made[0]
  482.         del db.external_made[:]
  483.         assert db.data_in == []
  484.         assert db.lost == []
  485.  
  486.         ca.write('aaa')
  487.         cb.write('bbb')
  488.         sleep(1)
  489.         
  490.         assert da.external_made == []
  491.         assert da.data_in == [(ca, 'bbb')]
  492.         del da.data_in[:]
  493.         assert da.lost == []
  494.         assert db.external_made == []
  495.         assert db.data_in == [(cb, 'aaa')]
  496.         del db.data_in[:]
  497.         assert db.lost == []
  498.  
  499.         ca.write('ccc')
  500.         cb.write('ddd')
  501.         sleep(1)
  502.         
  503.         assert da.external_made == []
  504.         assert da.data_in == [(ca, 'ddd')]
  505.         del da.data_in[:]
  506.         assert da.lost == []
  507.         assert db.external_made == []
  508.         assert db.data_in == [(cb, 'ccc')]
  509.         del db.data_in[:]
  510.         assert db.lost == []
  511.  
  512.         ca.close()
  513.         cb.close()
  514.         sleep(1)
  515.  
  516.         assert da.external_made == []
  517.         assert da.data_in == []
  518.         assert da.lost == []
  519.         assert db.external_made == []
  520.         assert db.data_in == []
  521.         assert db.lost == []
  522.     finally:
  523.         fa.set()
  524.         fb.set()
  525.  
  526. def test_normal():
  527.     l = []
  528.     f = Event()
  529.     s = RawServer(f, 100, 100)
  530.     loop(s)
  531.     sl(s, DummyHandler(), beginport + 7)
  532.     s.add_task(lambda l = l: l.append('b'), 2)
  533.     s.add_task(lambda l = l: l.append('a'), 1)
  534.     s.add_task(lambda l = l: l.append('d'), 4)
  535.     sleep(1.5)
  536.     s.add_task(lambda l = l: l.append('c'), 1.5)
  537.     sleep(3)
  538.     assert l == ['a', 'b', 'c', 'd']
  539.     f.set()
  540.  
  541. def test_catch_exception():
  542.     l = []
  543.     f = Event()
  544.     s = RawServer(f, 100, 100, False)
  545.     loop(s)
  546.     sl(s, DummyHandler(), beginport + 9)
  547.     s.add_task(lambda l = l: l.append('b'), 2)
  548.     s.add_task(lambda: 4/0, 1)
  549.     sleep(3)
  550.     assert l == ['b']
  551.     f.set()
  552.  
  553. def test_closes_if_not_hit():
  554.     try:
  555.         da = DummyHandler()
  556.         fa = Event()
  557.         sa = RawServer(fa, 2, 2)
  558.         loop(sa)
  559.         sl(sa, da, beginport + 14)
  560.  
  561.         sleep(1)
  562.         db = DummyHandler()
  563.         fb = Event()
  564.         sb = RawServer(fb, 100, 100)
  565.         loop(sb)
  566.         sl(sb, db, beginport + 13)
  567.         
  568.         sleep(.5)
  569.         sa.start_connection(('127.0.0.1', beginport + 13))
  570.         sleep(1)
  571.         
  572.         assert da.external_made == []
  573.         assert da.data_in == []
  574.         assert da.lost == []
  575.         assert len(db.external_made) == 1
  576.         del db.external_made[:]
  577.         assert db.data_in == []
  578.         assert db.lost == []
  579.  
  580.         sleep(3.1)
  581.         
  582.         assert len(da.lost) == 1
  583.         assert len(db.lost) == 1
  584.     finally:
  585.         fa.set()
  586.         fb.set()
  587.  
  588. def test_does_not_close_if_hit():
  589.     try:
  590.         fa = Event()
  591.         fb = Event()
  592.         da = DummyHandler()
  593.         sa = RawServer(fa, 2, 2)
  594.         loop(sa)
  595.         sl(sa, da, beginport + 12)
  596.  
  597.         sleep(1)
  598.         db = DummyHandler()
  599.         sb = RawServer(fb, 100, 100)
  600.         loop(sb)
  601.         sl(sb, db, beginport + 13)
  602.         
  603.         sleep(.5)
  604.         sa.start_connection(('127.0.0.1', beginport + 13))
  605.         sleep(1)
  606.         
  607.         assert da.external_made == []
  608.         assert da.data_in == []
  609.         assert da.lost == []
  610.         assert len(db.external_made) == 1
  611.         cb = db.external_made[0]
  612.         del db.external_made[:]
  613.         assert db.data_in == []
  614.         assert db.lost == []
  615.  
  616.         cb.write('bbb')
  617.         sleep(.5)
  618.         
  619.         assert da.lost == []
  620.         assert db.lost == []
  621.     finally:
  622.         fa.set()
  623.         fb.set()
  624.